"""Synthetic graph datasets."""
import math
import os
import pickle
import random

import networkx as nx
import numpy as np

from .. import backend as F
from ..batch import batch
from ..convert import graph
from ..transforms import reorder_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, download, load_graphs, save_graphs


class BAShapeDataset(DGLBuiltinDataset):
    r"""BA-SHAPES dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
    <https://arxiv.org/abs/1903.03894>`__

    A synthetic, single-graph dataset for node classification. The graph is built as follows.

    - Generate a base Barabási–Albert (BA) graph.
    - Create ``num_motifs`` five-node house-structured motifs.
    - Connect every motif to one node of the base graph. The anchor nodes are spaced evenly
      over the base node IDs, ``floor(num_base_nodes / num_motifs)`` apart.
    - Perturb the graph by adding random edges that are neither self-loops nor duplicates
      of existing edges.
    - Assign 4 node classes. Label 0 marks nodes of the base BA graph; labels 1, 2, 3 mark
      nodes at the middle, bottom, and top of a house respectively.
    - Give every node the constant feature 1.

    Parameters
    ----------
    num_base_nodes : int, optional
        Number of nodes in the base BA graph. Default: 300
    num_base_edges_per_node : int, optional
        Number of edges attached from each new node to existing nodes while growing the
        base BA graph. Default: 5
    num_motifs : int, optional
        Number of house motifs to attach. Default: 80
    perturb_ratio : float, optional
        Ratio between the number of random edges added in the perturbation step and the
        number of edges in the unperturbed graph. Default: 0.01
    seed : integer, random_state, or None, optional
        Indicator of random number generation state. It is forwarded to the BA graph
        generator and, when not None, used to seed ``numpy.random``. Default: None
    raw_dir : str, optional
        Raw file directory to store the processed data. Default: ~/.dgl/
    force_reload : bool, optional
        Whether to always generate the data from scratch rather than load a cached version.
        Default: False
    verbose : bool, optional
        Whether to print progress information. Default: True
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access. Default: None

    Attributes
    ----------
    num_classes : int
        Number of node classes

    Examples
    --------

    >>> from dgl.data import BAShapeDataset
    >>> dataset = BAShapeDataset()
    >>> dataset.num_classes
    4
    >>> g = dataset[0]
    >>> label = g.ndata['label']
    >>> feat = g.ndata['feat']
    """

    def __init__(
        self,
        num_base_nodes=300,
        num_base_edges_per_node=5,
        num_motifs=80,
        perturb_ratio=0.01,
        seed=None,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        transform=None,
    ):
        # Generation settings are stored before the base-class constructor runs.
        self.seed = seed
        self.perturb_ratio = perturb_ratio
        self.num_motifs = num_motifs
        self.num_base_edges_per_node = num_base_edges_per_node
        self.num_base_nodes = num_base_nodes
        super().__init__(
            name="BA-SHAPES",
            url=None,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        """Generate the BA-SHAPES graph and store it in ``self._graph``."""
        base = nx.barabasi_albert_graph(
            self.num_base_nodes, self.num_base_edges_per_node, self.seed
        )
        # Split the base edge list into parallel source/destination lists.
        src, dst = (list(ends) for ends in zip(*base.edges()))

        # ``total`` always holds the ID of the next node to be created.
        total = self.num_base_nodes
        # Every node of the base BA graph is of class 0.
        labels = [0] * total
        # Distance (in node IDs) between two consecutive anchor nodes.
        stride = math.floor(total / self.num_motifs)

        # House motif expressed with node offsets relative to its first node.
        house_offsets = ((0, 1), (1, 2), (2, 3), (3, 0), (4, 0), (4, 1))
        # Classes of the five house nodes: middle (1), bottom (2), top (3).
        house_labels = (1, 1, 2, 2, 3)

        for motif_id in range(self.num_motifs):
            for tail, head in house_offsets:
                src.append(total + tail)
                dst.append(total + head)
            labels.extend(house_labels)

            # Edge linking the house to its anchor in the base graph.
            src.append(total)
            dst.append(motif_id * stride)
            total += len(house_labels)

        g = graph((src, dst), num_nodes=total)

        # Perturbation: add random edges that are new and not self-loops.
        num_real_edges = g.num_edges()
        max_ratio = (total * (total - 1) - num_real_edges) / num_real_edges
        assert (
            self.perturb_ratio <= max_ratio
        ), "perturb_ratio cannot exceed {:.4f}".format(max_ratio)
        num_random_edges = int(num_real_edges * self.perturb_ratio)

        if self.seed is not None:
            np.random.seed(self.seed)
        added = 0
        while added < num_random_edges:
            u = np.random.randint(0, total)
            v = np.random.randint(0, total)
            # Reject self-loops and edges already present, then redraw.
            if u == v or g.has_edges_between(u, v):
                continue
            g.add_edges(u, v)
            added += 1

        g.ndata["label"] = F.tensor(labels, F.int64)
        g.ndata["feat"] = F.ones((total, 1), F.float32, F.cpu())
        self._graph = reorder_graph(
            g,
            node_permute_algo="rcmk",
            edge_permute_algo="dst",
            store_ids=False,
        )

    @property
    def graph_path(self):
        """Path of the file used to cache the generated graph."""
        file_name = "{}_dgl_graph.bin".format(self.name)
        return os.path.join(self.save_path, file_name)

    def save(self):
        """Write the generated graph to :attr:`graph_path`."""
        save_graphs(str(self.graph_path), self._graph)

    def has_cache(self):
        """Return whether a file exists at :attr:`graph_path`."""
        return os.path.exists(self.graph_path)

    def load(self):
        """Read the graph back from :attr:`graph_path`."""
        loaded, _ = load_graphs(str(self.graph_path))
        self._graph = loaded[0]

    def __getitem__(self, idx):
        """Return the only graph of the dataset, transformed if a transform is set."""
        assert idx == 0, "This dataset has only one graph."
        g = self._graph
        if self._transform is not None:
            g = self._transform(g)
        return g

    def __len__(self):
        """Number of graphs in the dataset, which is always 1."""
        return 1

    @property
    def num_classes(self):
        """Number of node classes, which is always 4."""
        return 4


class BACommunityDataset(DGLBuiltinDataset):
    r"""BA-COMMUNITY dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
    <https://arxiv.org/abs/1903.03894>`__

    A synthetic, single-graph dataset for node classification. The graph is the union of
    two BA-SHAPES graphs (see :class:`BAShapeDataset`) and is built as follows.

    - Generate two BA-SHAPES graphs with the same settings. Each one consists of a base
      Barabási–Albert (BA) graph with five-node house motifs attached, perturbed with
      random edges.
    - Nodes of the first graph keep the BA-SHAPES classes 0, 1, 2, 3 (base graph, middle,
      bottom, and top of a house). Nodes of the second graph get classes 4, 5, 6, 7.
    - Join the two graphs by adding random edges that go from nodes of the first graph to
      nodes of the second graph.
    - Draw normally distributed node features of length 10. The two graphs use distinct
      distributions that differ in the mean of the first two dimensions.

    Parameters
    ----------
    num_base_nodes : int, optional
        Number of nodes in each base BA graph. Default: 300
    num_base_edges_per_node : int, optional
        Number of edges attached from each new node to existing nodes while growing a base
        BA graph. Default: 4
    num_motifs : int, optional
        Number of house motifs attached in each of the two graphs. Default: 80
    perturb_ratio : float, optional
        Ratio between the number of random edges added to a graph in the perturbation step
        and the number of its original edges. Default: 0.01
    num_inter_edges : int, optional
        Number of random edges added between the two graphs. Default: 350
    seed : integer, random_state, or None, optional
        Indicator of random number generation state. When not None it seeds both
        ``random`` and ``numpy.random``. Default: None
    raw_dir : str, optional
        Raw file directory to store the processed data. Default: ~/.dgl/
    force_reload : bool, optional
        Whether to always generate the data from scratch rather than load a cached version.
        Default: False
    verbose : bool, optional
        Whether to print progress information. Default: True
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access. Default: None

    Attributes
    ----------
    num_classes : int
        Number of node classes

    Examples
    --------

    >>> from dgl.data import BACommunityDataset
    >>> dataset = BACommunityDataset()
    >>> dataset.num_classes
    8
    >>> g = dataset[0]
    >>> label = g.ndata['label']
    >>> feat = g.ndata['feat']
    """

    def __init__(
        self,
        num_base_nodes=300,
        num_base_edges_per_node=4,
        num_motifs=80,
        perturb_ratio=0.01,
        num_inter_edges=350,
        seed=None,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        transform=None,
    ):
        # Generation settings are stored before the base-class constructor runs.
        self.seed = seed
        self.num_inter_edges = num_inter_edges
        self.perturb_ratio = perturb_ratio
        self.num_motifs = num_motifs
        self.num_base_edges_per_node = num_base_edges_per_node
        self.num_base_nodes = num_base_nodes
        super().__init__(
            name="BA-COMMUNITY",
            url=None,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        """Generate the BA-COMMUNITY graph and store it in ``self._graph``."""
        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)

        # Build the two communities one after the other. No seed is handed to
        # BAShapeDataset, so they rely on the global random state set above.
        # NOTE(review): force_reload=True presumably makes each BAShapeDataset
        # rewrite its own on-disk cache as a side effect -- confirm.
        g1, g2 = (
            BAShapeDataset(
                num_base_nodes=self.num_base_nodes,
                num_base_edges_per_node=self.num_base_edges_per_node,
                num_motifs=self.num_motifs,
                perturb_ratio=self.perturb_ratio,
                force_reload=True,
                verbose=False,
            )[0]
            for _ in range(2)
        )

        # Merge both communities into one graph; the first half of the node IDs
        # belongs to g1 and the second half to g2.
        g = batch([g1, g2])
        half = g.num_nodes() // 2
        num_links = self.num_inter_edges
        link_src = np.random.randint(0, half, (num_links,))
        link_dst = np.random.randint(half, 2 * half, (num_links,))
        g.add_edges(
            F.astype(F.zerocopy_from_numpy(link_src), g.idtype),
            F.astype(F.zerocopy_from_numpy(link_dst), g.idtype),
        )

        # Shift the classes of the second community by 4.
        labels_first = g1.ndata["label"]
        labels_second = g2.ndata["label"] + 4
        g.ndata["label"] = F.cat([labels_first, labels_second], dim=0)

        # Node features: 10-dimensional Gaussians with a diagonal covariance.
        # Only the mean of the first two dimensions separates the communities.
        tail_mean = [0.0] * 8
        covariance = np.diag(np.array([0.5] * 2 + [1.0] * 8))
        mean_first = np.array([-1.0] * 2 + tail_mean)
        mean_second = np.array([1.0] * 2 + tail_mean)

        feat_first = np.random.multivariate_normal(mean_first, covariance, half)
        feat_second = np.random.multivariate_normal(
            mean_second, covariance, half
        )

        g.ndata["feat"] = F.zerocopy_from_numpy(
            np.concatenate([feat_first, feat_second], axis=0)
        )
        self._graph = reorder_graph(
            g,
            node_permute_algo="rcmk",
            edge_permute_algo="dst",
            store_ids=False,
        )

    @property
    def graph_path(self):
        """Path of the file used to cache the generated graph."""
        file_name = "{}_dgl_graph.bin".format(self.name)
        return os.path.join(self.save_path, file_name)

    def save(self):
        """Write the generated graph to :attr:`graph_path`."""
        save_graphs(str(self.graph_path), self._graph)

    def has_cache(self):
        """Return whether a file exists at :attr:`graph_path`."""
        return os.path.exists(self.graph_path)

    def load(self):
        """Read the graph back from :attr:`graph_path`."""
        loaded, _ = load_graphs(str(self.graph_path))
        self._graph = loaded[0]

    def __getitem__(self, idx):
        """Return the only graph of the dataset, transformed if a transform is set."""
        assert idx == 0, "This dataset has only one graph."
        g = self._graph
        if self._transform is not None:
            g = self._transform(g)
        return g

    def __len__(self):
        """Number of graphs in the dataset, which is always 1."""
        return 1

    @property
    def num_classes(self):
        """Number of node classes, which is always 8."""
        return 8


class TreeCycleDataset(DGLBuiltinDataset):
    r"""TREE-CYCLES dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
    <https://arxiv.org/abs/1903.03894>`__

    This is a synthetic dataset for node classification. It is generated by performing the
    following steps in order.

    - Construct a balanced binary tree as the base graph.
    - Construct a set of cycle motifs with ``cycle_size`` nodes each.
    - Attach every motif to one node of the base graph. The anchor nodes are spaced evenly
      over the tree node IDs. For roughly half of the motifs (chosen at random) one extra
      edge is added from a cycle node to a node whose ID is 1 to 3 larger than the anchor.
    - Perturb the graph by adding random edges.
    - Generate constant feature for all nodes, which is 1.
    - Nodes in the tree belong to class 0 and nodes in cycles belong to class 1.

    Parameters
    ----------
    tree_height : int, optional
        Height of the balanced binary tree. Default: 8
    num_motifs : int, optional
        Number of cycle motifs to use. Default: 60
    cycle_size : int, optional
        Number of nodes in a cycle motif. Must be at least 3. Default: 6
    perturb_ratio : float, optional
        Number of random edges to add in perturbation divided by the
        number of original edges in the graph. Default: 0.01
    seed : integer, random_state, or None, optional
        Indicator of random number generation state. When not None it seeds
        ``numpy.random``. Default: None
    raw_dir : str, optional
        Raw file directory to store the processed data. Default: ~/.dgl/
    force_reload : bool, optional
        Whether to always generate the data from scratch rather than load a cached version.
        Default: False
    verbose : bool, optional
        Whether to print progress information. Default: True
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access. Default: None

    Attributes
    ----------
    num_classes : int
        Number of node classes

    Examples
    --------

    >>> from dgl.data import TreeCycleDataset
    >>> dataset = TreeCycleDataset()
    >>> dataset.num_classes
    2
    >>> g = dataset[0]
    >>> label = g.ndata['label']
    >>> feat = g.ndata['feat']
    """

    def __init__(
        self,
        tree_height=8,
        num_motifs=60,
        cycle_size=6,
        perturb_ratio=0.01,
        seed=None,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        transform=None,
    ):
        # Generation settings are stored before the base-class constructor runs.
        self.tree_height = tree_height
        self.num_motifs = num_motifs
        self.cycle_size = cycle_size
        self.perturb_ratio = perturb_ratio
        self.seed = seed
        super(TreeCycleDataset, self).__init__(
            name="TREE-CYCLES",
            url=None,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        """Generate the TREE-CYCLES graph and store it in ``self._graph``.

        Raises
        ------
        ValueError
            If ``cycle_size`` is smaller than 3.
        """
        cycle_size = self.cycle_size
        if cycle_size < 3:
            raise ValueError(
                "cycle_size must be at least 3, got {}".format(cycle_size)
            )

        if self.seed is not None:
            np.random.seed(self.seed)

        g = nx.balanced_tree(r=2, h=self.tree_height)
        edges = list(g.edges())
        src, dst = map(list, zip(*edges))
        n = nx.number_of_nodes(g)

        # Nodes in the base tree graph belong to class 0
        node_labels = [0] * n
        # The motifs will be evenly attached to the nodes in the base graph.
        spacing = math.floor(n / self.num_motifs)

        # A cycle node at offset 1..3 may receive an extra edge below. The upper
        # bound is clamped so that the offset always stays inside the cycle; for
        # cycle_size >= 4 this is the same range as the historical ``randint(1, 4)``.
        max_extra_offset = min(4, cycle_size)

        for motif_id in range(self.num_motifs):
            # Construct a cycle over nodes n, n + 1, ..., n + cycle_size - 1.
            # The modulo closes the cycle by linking the last node back to the
            # first one.
            for i in range(cycle_size):
                src.append(n + i)
                dst.append(n + (i + 1) % cycle_size)

            # Nodes in cycles belong to class 1
            node_labels.extend([1] * cycle_size)

            # Attach the motif to the base tree graph
            anchor = int(motif_id * spacing)
            src.append(n)
            dst.append(anchor)

            # With probability about one half, add a second link from the cycle
            # to a node close (by ID) to the anchor.
            if np.random.random() > 0.5:
                a = np.random.randint(1, max_extra_offset)
                b = np.random.randint(1, 4)
                src.append(n + a)
                dst.append(anchor + b)

            n += cycle_size

        g = graph((src, dst), num_nodes=n)

        # Perturb the graph by adding non-self-loop random edges
        num_real_edges = g.num_edges()
        max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges
        assert (
            self.perturb_ratio <= max_ratio
        ), "perturb_ratio cannot exceed {:.4f}".format(max_ratio)
        num_random_edges = int(num_real_edges * self.perturb_ratio)

        for _ in range(num_random_edges):
            # Redraw until the pair is neither an existing edge nor a self-loop.
            while True:
                u = np.random.randint(0, n)
                v = np.random.randint(0, n)
                if (not g.has_edges_between(u, v)) and (u != v):
                    break
            g.add_edges(u, v)

        g.ndata["label"] = F.tensor(node_labels, F.int64)
        g.ndata["feat"] = F.ones((n, 1), F.float32, F.cpu())
        self._graph = reorder_graph(
            g,
            node_permute_algo="rcmk",
            edge_permute_algo="dst",
            store_ids=False,
        )

    @property
    def graph_path(self):
        """Path of the file used to cache the generated graph."""
        return os.path.join(
            self.save_path, "{}_dgl_graph.bin".format(self.name)
        )

    def save(self):
        """Write the generated graph to :attr:`graph_path`."""
        save_graphs(str(self.graph_path), self._graph)

    def has_cache(self):
        """Return whether a file exists at :attr:`graph_path`."""
        return os.path.exists(self.graph_path)

    def load(self):
        """Read the graph back from :attr:`graph_path`."""
        graphs, _ = load_graphs(str(self.graph_path))
        self._graph = graphs[0]

    def __getitem__(self, idx):
        """Return the only graph of the dataset, transformed if a transform is set."""
        assert idx == 0, "This dataset has only one graph."
        if self._transform is None:
            return self._graph
        else:
            return self._transform(self._graph)

    def __len__(self):
        """Number of graphs in the dataset, which is always 1."""
        return 1

    @property
    def num_classes(self):
        """Number of node classes, which is always 2."""
        return 2


class TreeGridDataset(DGLBuiltinDataset):
    r"""TREE-GRIDS dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
    <https://arxiv.org/abs/1903.03894>`__

    A synthetic, single-graph dataset for node classification. The graph is built as follows.

    - Generate a balanced binary tree as the base graph.
    - Create ``num_motifs`` copies of a ``grid_size``-by-``grid_size`` grid motif.
    - Connect every motif to one node of the tree. The anchor nodes are spaced evenly over
      the tree node IDs.
    - Perturb the graph by adding random edges that are neither self-loops nor duplicates
      of existing edges.
    - Give every node the constant feature 1.
    - Tree nodes are of class 0, grid nodes are of class 1.

    Parameters
    ----------
    tree_height : int, optional
        Height of the balanced binary tree. Default: 8
    num_motifs : int, optional
        Number of grid motifs to attach. Default: 80
    grid_size : int, optional
        Side length of a grid motif; each motif has grid_size ^ 2 nodes. Default: 3
    perturb_ratio : float, optional
        Ratio between the number of random edges added in the perturbation step and the
        number of original edges in the graph. Default: 0.1
    seed : integer, random_state, or None, optional
        Indicator of random number generation state. When not None it seeds
        ``numpy.random``. Default: None
    raw_dir : str, optional
        Raw file directory to store the processed data. Default: ~/.dgl/
    force_reload : bool, optional
        Whether to always generate the data from scratch rather than load a cached version.
        Default: False
    verbose : bool, optional
        Whether to print progress information. Default: True
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access. Default: None

    Attributes
    ----------
    num_classes : int
        Number of node classes

    Examples
    --------

    >>> from dgl.data import TreeGridDataset
    >>> dataset = TreeGridDataset()
    >>> dataset.num_classes
    2
    >>> g = dataset[0]
    >>> label = g.ndata['label']
    >>> feat = g.ndata['feat']
    """

    def __init__(
        self,
        tree_height=8,
        num_motifs=80,
        grid_size=3,
        perturb_ratio=0.1,
        seed=None,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        transform=None,
    ):
        # Generation settings are stored before the base-class constructor runs.
        self.seed = seed
        self.perturb_ratio = perturb_ratio
        self.grid_size = grid_size
        self.num_motifs = num_motifs
        self.tree_height = tree_height
        super().__init__(
            name="TREE-GRIDS",
            url=None,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        """Generate the TREE-GRIDS graph and store it in ``self._graph``."""
        if self.seed is not None:
            np.random.seed(self.seed)

        tree = nx.balanced_tree(r=2, h=self.tree_height)
        # Split the tree edge list into parallel source/destination lists.
        src, dst = (list(ends) for ends in zip(*tree.edges()))

        # ``total`` always holds the ID of the next node to be created.
        total = nx.number_of_nodes(tree)
        # Every node of the base tree is of class 0.
        labels = [0] * total
        # Distance (in node IDs) between two consecutive anchor nodes.
        stride = math.floor(total / self.num_motifs)

        # Template grid with nodes relabelled to 0 .. k*k - 1; each motif is a
        # copy of it shifted by the current node count.
        side = self.grid_size
        template = nx.convert_node_labels_to_integers(
            nx.grid_graph([side, side]), first_label=0
        )
        nodes_per_grid = nx.number_of_nodes(template)
        tmpl_src, tmpl_dst = (
            np.array(ends) for ends in zip(*template.edges())
        )

        for motif_id in range(self.num_motifs):
            src.extend((tmpl_src + total).tolist())
            dst.extend((tmpl_dst + total).tolist())

            # Grid nodes are of class 1.
            labels.extend([1] * nodes_per_grid)

            # Edge linking the grid to its anchor in the tree.
            src.append(total)
            dst.append(motif_id * stride)

            total += nodes_per_grid

        g = graph((src, dst), num_nodes=total)

        # Perturbation: add random edges that are new and not self-loops.
        num_real_edges = g.num_edges()
        max_ratio = (total * (total - 1) - num_real_edges) / num_real_edges
        assert (
            self.perturb_ratio <= max_ratio
        ), "perturb_ratio cannot exceed {:.4f}".format(max_ratio)
        num_random_edges = int(num_real_edges * self.perturb_ratio)

        added = 0
        while added < num_random_edges:
            u = np.random.randint(0, total)
            v = np.random.randint(0, total)
            # Reject self-loops and edges already present, then redraw.
            if u == v or g.has_edges_between(u, v):
                continue
            g.add_edges(u, v)
            added += 1

        g.ndata["label"] = F.tensor(labels, F.int64)
        g.ndata["feat"] = F.ones((total, 1), F.float32, F.cpu())
        self._graph = reorder_graph(
            g,
            node_permute_algo="rcmk",
            edge_permute_algo="dst",
            store_ids=False,
        )

    @property
    def graph_path(self):
        """Path of the file used to cache the generated graph."""
        file_name = "{}_dgl_graph.bin".format(self.name)
        return os.path.join(self.save_path, file_name)

    def save(self):
        """Write the generated graph to :attr:`graph_path`."""
        save_graphs(str(self.graph_path), self._graph)

    def has_cache(self):
        """Return whether a file exists at :attr:`graph_path`."""
        return os.path.exists(self.graph_path)

    def load(self):
        """Read the graph back from :attr:`graph_path`."""
        loaded, _ = load_graphs(str(self.graph_path))
        self._graph = loaded[0]

    def __getitem__(self, idx):
        """Return the only graph of the dataset, transformed if a transform is set."""
        assert idx == 0, "This dataset has only one graph."
        g = self._graph
        if self._transform is not None:
            g = self._transform(g)
        return g

    def __len__(self):
        """Number of graphs in the dataset, which is always 1."""
        return 1

    @property
    def num_classes(self):
        """Number of node classes, which is always 2."""
        return 2


class BA2MotifDataset(DGLBuiltinDataset):
    r"""BA-2motifs dataset from `Parameterized Explainer for Graph Neural Network
    <https://arxiv.org/abs/2011.04573>`__

    A synthetic dataset for graph classification. Unlike the other datasets in this
    module it is not generated locally: a pre-built pickle file is downloaded and
    converted into graphs. The data was generated with the following steps.

    - Construct 1000 base Barabási–Albert (BA) graphs.
    - Attach house-structured network motifs to half of the base BA graphs.
    - Attach five-node cycle motifs to the remaining base BA graphs.
    - Assign each graph to one of two classes according to the type of the attached motif.

    Parameters
    ----------
    raw_dir : str, optional
        Raw file directory to download and store the data. Default: ~/.dgl/
    force_reload : bool, optional
        Whether to reload the dataset. Default: False
    verbose : bool, optional
        Whether to print progress information. Default: True
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access. Default: None

    Attributes
    ----------
    num_classes : int
        Number of graph classes

    Examples
    --------

    >>> from dgl.data import BA2MotifDataset
    >>> dataset = BA2MotifDataset()
    >>> dataset.num_classes
    2
    >>> # Get the first graph and its label
    >>> g, label = dataset[0]
    >>> feat = g.ndata['feat']
    """

    def __init__(
        self, raw_dir=None, force_reload=False, verbose=True, transform=None
    ):
        super().__init__(
            name="BA-2motifs",
            url=_get_dgl_url("dataset/BA-2motif.pkl"),
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    @property
    def _pickle_path(self):
        """Location of the raw pickle file inside ``raw_dir``."""
        return os.path.join(self.raw_dir, self.name + ".pkl")

    def download(self):
        r"""Automatically download data."""
        # Inside the method body ``download`` resolves to the module-level helper.
        download(self.url, path=self._pickle_path)

    def process(self):
        """Convert the raw pickle file into graphs and a label tensor."""
        # SECURITY NOTE: pickle.load can execute arbitrary code. The file read
        # here is the one fetched from ``self.url`` by ``download``; only use
        # this loader with a trusted source.
        with open(self._pickle_path, "rb") as fin:
            adjs, features, labels = pickle.load(fin)

        self.labels = F.tensor(labels, F.int64)
        self.graphs = []

        for idx, adj in enumerate(adjs):
            # The non-zero entries of the adjacency matrix are the edges.
            g = graph(adj.nonzero())
            g.ndata["feat"] = F.zerocopy_from_numpy(features[idx])
            self.graphs.append(g)

    @property
    def graph_path(self):
        """Path of the file used to cache the processed graphs and labels."""
        file_name = "{}_dgl_graph.bin".format(self.name)
        return os.path.join(self.save_path, file_name)

    def save(self):
        """Write the graphs and their labels to :attr:`graph_path`."""
        save_graphs(
            str(self.graph_path), self.graphs, {"labels": self.labels}
        )

    def has_cache(self):
        """Return whether a file exists at :attr:`graph_path`."""
        return os.path.exists(self.graph_path)

    def load(self):
        """Read the graphs and their labels back from :attr:`graph_path`."""
        loaded_graphs, extras = load_graphs(str(self.graph_path))
        self.graphs = loaded_graphs
        self.labels = extras["labels"]

    def __getitem__(self, idx):
        """Return the ``idx``-th graph (transformed if a transform is set) and its label."""
        g = self.graphs[idx]
        if self._transform is None:
            return g, self.labels[idx]
        return self._transform(g), self.labels[idx]

    def __len__(self):
        """Number of graphs in the dataset."""
        return len(self.graphs)

    @property
    def num_classes(self):
        """Number of graph classes, which is always 2."""
        return 2
